##------------------------------------------------------------------------------
## Required library
##------------------------------------------------------------------------------
library(rstan)

#------ Use rstan locally on a multicore machine
# Write compiled Stan models to disk so repeated stan() calls on the same
# .stan file can reuse the compiled model.
rstan_options(auto_write = TRUE)
# parallel::detectCores() may return NA when the core count cannot be
# determined; max(..., na.rm = TRUE) then falls back to a single core
# instead of setting mc.cores = NA.
options(mc.cores = max(1L, parallel::detectCores(), na.rm = TRUE))

#------ Load the main dataset
# load() restores the saved objects into the current environment; this
# script expects one of them to be `Master`, whose columns are read with `$`.
load("Master.RData")

#------ Set Treatment (TRT), Outcome (OUT), Mediators (M) and Covariates (X)
Data <- Master

OUT <- Data$PM.2.5  # outcome
TRT <- Data$SO2.SC  # treatment; rows are later split on TRT == 0 / TRT == 1

# Mediators: one column per pollutant.
M <- cbind(
  Data$SO2_Annual,
  Data$NOx_Annual,
  Data$CO2_Annual
)

# Covariates; Heat_Input and Operating_Time are divided by constants to
# shrink their numeric range.
X <- cbind(
  Data$S_n_CR,
  Data$NumNOxControls,
  Data$Heat_Input / 100000,
  Data$Barometric_Pressure,
  Data$Temperature,
  Data$PctCapacity,
  Data$sulfur_Content,
  Data$Phase2_Indicator,
  Data$Operating_Time / 1000
)

dim.cov <- ncol(X)  # number of covariates

#------ Variables by treatments
# Row indices of each treatment arm. which() skips rows where TRT is NA, so
# such units enter neither arm.
idx0 <- which(TRT == 0)
idx1 <- which(TRT == 1)

# drop = FALSE keeps x0/x1/m0/m1 as matrices even when an arm has a single
# row; otherwise they collapse to vectors and nrow() returns NULL.
x0 <- X[idx0, , drop = FALSE]
x1 <- X[idx1, , drop = FALSE]

y0 <- OUT[idx0]
y1 <- OUT[idx1]

# Mediators are modelled on the log scale: a zero would become -Inf and a
# negative value NaN (which lm() would silently drop), so reject them early.
if (any(M[c(idx0, idx1), ] <= 0, na.rm = TRUE)) {
  stop("Mediators in M must be strictly positive before taking log().",
       call. = FALSE)
}
m0 <- log(M[idx0, , drop = FALSE])
m1 <- log(M[idx1, , drop = FALSE])

n0 <- nrow(x0)
n1 <- nrow(x1)

#------ Sampler settings
# Maximum number of clusters; passed to the Stan model as data `C`.
C <- 10
# Iterations per chain, passed to rstan::stan() as `iter` (by default the
# first half of these are warmup).
n.iteration <- 4e4


#------ (Optional) Fit ordinary linear models to set hyper-parameters
# Each arm's outcome and log-mediators are regressed on that arm's
# covariates. Only the intercepts and the marginal precisions (1 / variance)
# are kept, and passed to Stan as `mu_sub_mean` and `precision`.

# Outcome regressions, one per arm.
lm.y0 <- lm(y0 ~ x0)
mu0_sub_mean <- coef(lm.y0)[1]  # intercept
y0_precision <- 1 / var(y0)

lm.y1 <- lm(y1 ~ x1)
mu1_sub_mean <- coef(lm.y1)[1]
y1_precision <- 1 / var(y1)

# Multivariate mediator regressions: coef() returns a matrix whose first
# row holds one intercept per mediator column.
lm.m0 <- lm(m0 ~ x0)
m_mu0_sub_mean <- coef(lm.m0)[1, ]
m0_precision <- diag(1 / var(m0))  # reciprocal of each column's variance

lm.m1 <- lm(m1 ~ x1)
m_mu1_sub_mean <- coef(lm.m1)[1, ]
m1_precision <- diag(1 / var(m1))


#------ Stan data and model for fitting Y0 and Y1 distributions
# One single-chain fit per treatment arm. `pars` limits the stored draws to
# the listed parameters. `chains` is spelled out in full rather than relying
# on partial argument matching.
stan_data <- list(C = C, n = n0, out = y0, x = x0, cov = dim.cov,
                  precision = y0_precision, mu_sub_mean = mu0_sub_mean)

fit.y0 <- stan(file = "stan_code.stan", data = stan_data,
               pars = c("gamma0", "gamma1", "W", "psi"),
               iter = n.iteration, chains = 1,
               cores = getOption("mc.cores", 1L))

stan_data <- list(C = C, n = n1, out = y1, x = x1, cov = dim.cov,
                  precision = y1_precision, mu_sub_mean = mu1_sub_mean)

fit.y1 <- stan(file = "stan_code.stan", data = stan_data,
               pars = c("gamma0", "gamma1", "W", "psi"),
               iter = n.iteration, chains = 1,
               cores = getOption("mc.cores", 1L))


#------ Stan data and model for fitting M11, M12, M13 distributions
# One single-chain fit per mediator column for the TRT == 1 arm, stored as
# fit.m11, fit.m12, ... . assign() creates each variable directly (no
# eval(parse())). The loop index is `j` so the global `T` (TRUE) is never
# overwritten.
for (j in seq_len(ncol(M))) {
  stan_data <- list(C = C, n = n1, out = m1[, j], x = x1, cov = dim.cov,
                    precision = m1_precision[j],
                    mu_sub_mean = m_mu1_sub_mean[j])

  assign(
    paste0("fit.m1", j),
    stan(file = "stan_code.stan", data = stan_data,
         pars = c("gamma0", "gamma1", "W", "psi"),
         iter = n.iteration, chains = 1,
         cores = getOption("mc.cores", 1L))
  )
}


#------ Stan data and model for fitting M01, M02, M03 distributions
# One single-chain fit per mediator column for the TRT == 0 arm, stored as
# fit.m01, fit.m02, ... . assign() creates each variable directly (no
# eval(parse())). The loop index is `j` so the global `T` (TRUE) is never
# overwritten.
for (j in seq_len(ncol(M))) {
  stan_data <- list(C = C, n = n0, out = m0[, j], x = x0, cov = dim.cov,
                    precision = m0_precision[j],
                    mu_sub_mean = m_mu0_sub_mean[j])

  assign(
    paste0("fit.m0", j),
    stan(file = "stan_code.stan", data = stan_data,
         pars = c("gamma0", "gamma1", "W", "psi"),
         iter = n.iteration, chains = 1,
         cores = getOption("mc.cores", 1L))
  )
}

# Persist all posterior fits together with the per-arm data they were fit
# on. The fit.m0* / fit.m1* names below assume M has exactly three columns.
save(
  list = c(
    "fit.y0", "fit.y1",
    "fit.m01", "fit.m02", "fit.m03",
    "fit.m11", "fit.m12", "fit.m13",
    "x0", "y0", "m0", "n0",
    "x1", "y1", "m1", "n1"
  ),
  file = "MCMCsample.RData"
)

